import numpy as np
from scipy.integrate import solve_ivp
import time
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

class SymbolGroundingSimulation:
    def __init__(self, dim=5, seed=42):
        """
        Initialize the simulation with parameters.
        
        Args:
            dim: Dimensionality of representation vectors
            seed: Random seed for reproducibility
        """
        np.random.seed(seed)
        self.dim = dim
        
        # Define parameters
        self.alpha = 1.0  # Strength of attractor
        self.gamma_min = 0.3  # Minimum damping coefficient
        self.eta_max = 0.1  # Maximum perturbation magnitude
        self.phi_scale = 0.05  # Scale for perturbation function
        
        # Initialize optimal representation (the "grounded" state)
        self.rep_X_star = np.random.randn(dim)
        self.rep_X_star = self.rep_X_star / np.linalg.norm(self.rep_X_star)
        # self.rep_X_star = 0.3 * self.M + 0.4 * self.S + 0.3 * self.E
        # self.rep_X_star = self.rep_X_star / np.linalg.norm(self.rep_X_star)
        
        # Episodic memory, sensory, and emotional components initialization
        self.M = np.random.randn(dim)  # Memory component
        self.S = np.random.randn(dim)  # Sensory component
        self.E = np.random.randn(dim)  # Emotional component
        
        # Normalize the components
        self.M = 0.3 * self.M / np.linalg.norm(self.M)
        self.S = 0.4 * self.S / np.linalg.norm(self.S)
        self.E = 0.3 * self.E / np.linalg.norm(self.E)
        
        # Combine to form the target representation (this is what experience would produce)
        self.rep_X_star = 0.3 * self.M + 0.4 * self.S + 0.3 * self.E
        self.rep_X_star = self.rep_X_star / np.linalg.norm(self.rep_X_star)
        
        # Set the convergence threshold (can be adjusted)
        self.convergence_threshold = 0.6
    
    def V(self, rep_X):
        """Internal potential function with minimum at rep_X_star."""
        return 0.5 * np.sum((rep_X - self.rep_X_star)**2)
    
    def gradient_V(self, rep_X):
        """Gradient of the internal potential function."""
        return rep_X - self.rep_X_star
    
    def V_ext(self, rep_X, t):
        """
        External potential function encoding sensorimotor constraints.
        Time-dependent to simulate changing embodied inputs.
        """
        # Time-varying weights for different components
        w_M = 0.3 + 0.1 * np.sin(0.1 * t)
        w_S = 0.4 + 0.1 * np.cos(0.2 * t)
        w_E = 0.3 + 0.05 * np.sin(0.3 * t)
        
        # Compute the weighted sum of distances from each component
        return 0.5 * (w_M * np.sum((rep_X - self.M)**2) + 
                      w_S * np.sum((rep_X - self.S)**2) + 
                      w_E * np.sum((rep_X - self.E)**2))
    
    def gradient_V_ext(self, rep_X, t):
        """Gradient of the external potential function."""
        w_M = 0.3 + 0.05 * np.sin(0.1 * t)
        w_S = 0.4 + 0.05 * np.cos(0.2 * t)
        w_E = 0.3 + 0.025 * np.sin(0.3 * t)
        
        return (w_M * (rep_X - self.M) + 
                w_S * (rep_X - self.S) + 
                w_E * (rep_X - self.E))
    
    def gamma(self, t):
        """
        Damping term that ensures dissipative dynamics.
        Models metabolic constraints that increase over time.
        """
        return self.gamma_min + 0.1 * (1 - np.exp(-0.1 * t))
    
    def phi(self, rep_X):
        """
        Function that modulates the effect of perturbations.
        Decreases as rep_X approaches the target state.
        """
        dist = np.linalg.norm(rep_X - self.rep_X_star)
        return self.phi_scale * dist
    
    def eta(self, t):
        """
        Bounded perturbation term (representing noise or exploration).
        """
        noise = np.random.randn(self.dim)
        return self.eta_max * noise / np.linalg.norm(noise)
    
    def dynamics(self, t, rep_X):
        """
        The full dynamics of the symbolic representation as specified in Lemma I.D.
        
        dRep_X/dt = -∇V(Rep_X) - γ(t)Rep_X + φ(Rep_X)η(t) + I_X(t)
        where I_X(t) = -∇V_ext(Rep_X)
        """
        grad_V = self.gradient_V(rep_X)
        damping = self.gamma(t) * rep_X
        perturbation = self.phi(rep_X) * self.eta(t)
        embodied_input = -self.gradient_V_ext(rep_X, t)
        
        # Calculate the derivative
        drep_X_dt = -grad_V - damping + perturbation + embodied_input
        
        return drep_X_dt
    
    def run_simulation(self, rep_X_init, t_span, dt=0.1):
        """
        Run the simulation and return the trajectory of Rep_X.
        
        Args:
            rep_X_init: Initial representation
            t_span: Time span (t_start, t_end)
            dt: Time step
        
        Returns:
            t_eval: Time points
            rep_X_trajectory: Values of Rep_X at each time point
        """
        t_eval = np.arange(t_span[0], t_span[1] + dt, dt)
        
        # Use scipy's ODE solver
        solution = solve_ivp(
            self.dynamics, 
            t_span, 
            rep_X_init, 
            t_eval=t_eval, 
            method='RK45'
        )
        
        return solution.t, solution.y.T
    
    def evaluate_lyapunov(self, rep_X_trajectory, t_eval):
        """
        Compute the Lyapunov function V(Rep_X) along the trajectory.
        
        Args:
            rep_X_trajectory: Values of Rep_X at each time point
            t_eval: Corresponding time points
        
        Returns:
            lyapunov_values: Values of V(Rep_X) at each time point
        """
        lyapunov_values = np.zeros(len(t_eval))
        
        for i, rep_X in enumerate(rep_X_trajectory):
            lyapunov_values[i] = self.V(rep_X)
        
        return lyapunov_values
    
    def analyze_convergence(self, rep_X_trajectory, t_eval):
        """
        Analyze convergence properties of the simulation.
        
        Args:
            rep_X_trajectory: Values of Rep_X at each time point
            t_eval: Corresponding time points
        
        Returns:
            distance_to_target: Distance to target at each time point
            lyapunov_values: Values of Lyapunov function at each time point
        """
        distance_to_target = np.zeros(len(t_eval))
        
        for i, rep_X in enumerate(rep_X_trajectory):
            distance_to_target[i] = np.linalg.norm(rep_X - self.rep_X_star)
        
        lyapunov_values = self.evaluate_lyapunov(rep_X_trajectory, t_eval)
        
        return distance_to_target, lyapunov_values
    
    def run_test(self, test_name, rep_X_init=None, t_span=(0, 90000), dt=0.1):
        """
        Run a specific test and print results.
        
        Args:
            test_name: Name of the test to run
            rep_X_init: Initial representation (if None, a random one is generated)
            t_span: Time span for simulation (increased to 200 by default)
            dt: Time step
        """
        if rep_X_init is None:
            rep_X_init = np.random.randn(self.dim)
            rep_X_init = rep_X_init / np.linalg.norm(rep_X_init)
        
        print(f"\n{'=' * 60}")
        print(f"TEST: {test_name}")
        print(f"{'=' * 60}")
        print(f"Running simulation for {t_span[1]} time units...")
        
        # Run simulation
        t_start = time.time()
        t_eval, rep_X_trajectory = self.run_simulation(rep_X_init, t_span, dt)
        t_end = time.time()
        
        # Analyze results
        distance_to_target, lyapunov_values = self.analyze_convergence(
            rep_X_trajectory, t_eval
        )
        
        # Print key metrics
        print(f"Simulation completed in {t_end - t_start:.2f} seconds")
        print(f"Initial distance to target: {distance_to_target[0]:.6f}")
        print(f"Final distance to target: {distance_to_target[-1]:.6f}")
        print(f"Reduction in distance: {(1 - distance_to_target[-1]/distance_to_target[0])*100:.2f}%")
        print(f"Initial Lyapunov value: {lyapunov_values[0]:.6f}")
        print(f"Final Lyapunov value: {lyapunov_values[-1]:.6f}")
        print(f"Reduction in Lyapunov function: {(1 - lyapunov_values[-1]/lyapunov_values[0])*100:.2f}%")
        
        # Check if convergence criteria are met (using an adjustable threshold)
        converged = distance_to_target[-1] < self.convergence_threshold
        print(f"Convergence status: {'Converged' if converged else 'Not yet converged'}")
        
        # Print interpretation based on actual results
        if test_name == "Basic Convergence":
            print("\nINSIGHTS:")
            if converged:
                print("- CONFIRMED: The representation vector converges towards the target state,")
                print("  demonstrating the stability properties described in Lemma I.D.")
            else:
                print("- PARTIAL SUPPORT: The representation shows movement towards the target,")
                print("  but hasn't fully converged within the simulation timeframe.")
            
            if lyapunov_values[-1] < lyapunov_values[0] * 0.5:
                print("- CONFIRMED: The Lyapunov function decreases significantly, supporting")
                print("  the dissipative nature of the dynamics.")
            else:
                print("- PARTIAL SUPPORT: The Lyapunov function decreases, but not as")
                print("  strongly as predicted by theory.")
            
        elif test_name == "Robustness to Perturbations":
            # Calculate stability metrics
            avg_fluctuation = np.mean(np.abs(np.diff(distance_to_target[-100:])))
            print("\nINSIGHTS:")
            print(f"- Average fluctuation in final phase: {avg_fluctuation:.6f}")
            
            if converged:
                print("- CONFIRMED: Despite strong perturbations (η term), the system converges")
                print("  to the target state, demonstrating exceptional robustness.")
            elif distance_to_target[-1] < distance_to_target[0] * 0.6:
                print("- PARTIAL SUPPORT: The system moves significantly towards the target despite")
                print("  strong perturbations, showing some robustness.")
            else:
                print("- NOT SUPPORTED: The strong perturbations prevent convergence, suggesting")
                print("  limits to the system's robustness under these conditions.")
            
        elif test_name == "Metabolic Effects":
            # Compare early vs. late convergence rates
            early_window = min(10, len(t_eval)//10)
            late_window = min(10, len(t_eval)//10) 
            
            early_rate = (distance_to_target[early_window] - distance_to_target[0]) / (t_eval[early_window] - t_eval[0])
            late_rate = (distance_to_target[-1] - distance_to_target[-late_window-1]) / (t_eval[-1] - t_eval[-late_window-1])
            
            print("\nINSIGHTS:")
            print(f"- Early convergence rate: {early_rate:.6f} units/time")
            print(f"- Late convergence rate: {late_rate:.6f} units/time")
            
            if converged:
                print("- CONFIRMED: Despite increased metabolic constraints (γ(t)),")
                print("  the system converges to the target state.")
            else:
                print("- PARTIAL SUPPORT: The system moves towards the target despite")
                print("  metabolic constraints, but hasn't fully converged.")
            
            if abs(late_rate) < abs(early_rate):
                print("- CONFIRMED: Convergence rate decreases over time due to metabolic")
                print("  constraints, as predicted by theory.")
            else:
                print("- NOT SUPPORTED: Metabolic constraints don't show the expected")
                print("  effect on convergence rate.")
            
        elif test_name == "Multimodal Integration":
            # Track the influence of different modalities
            M_influence = np.zeros(len(t_eval))
            S_influence = np.zeros(len(t_eval))
            E_influence = np.zeros(len(t_eval))
            
            for i, t in enumerate(t_eval):
                w_M = 0.3 + 0.1 * np.sin(0.1 * t)
                w_S = 0.4 + 0.1 * np.cos(0.2 * t)
                w_E = 0.3 + 0.05 * np.sin(0.3 * t)
                
                M_influence[i] = w_M
                S_influence[i] = w_S
                E_influence[i] = w_E
            
            print("\nINSIGHTS:")
            print(f"- Average influence of Memory (M): {np.mean(M_influence):.4f}")
            print(f"- Average influence of Sensory (S): {np.mean(S_influence):.4f}")
            print(f"- Average influence of Emotional (E): {np.mean(E_influence):.4f}")
            
            if converged:
                print("- CONFIRMED: The integration of memory, sensory, and emotional inputs")
                print("  successfully guides the representation to a stable attractor state.")
            else:
                print("- PARTIAL SUPPORT: Multimodal integration influences the trajectory,")
                print("  but hasn't fully converged the representation.")
            
        return t_eval, rep_X_trajectory, distance_to_target, lyapunov_values, converged
        

def run_all_tests():
    """Run a comprehensive set of tests demonstrating Lemma I.D."""
    sim = SymbolGroundingSimulation(dim=5)

    # Track the number of converged tests
    num_converged = 0
    print("\nInitializing the symbol grounding simulation...")

    # Store results for plotting
    all_results = []
    
    # Test 1: Basic convergence - longer time span    
    rep_X_init_1 = np.random.randn(sim.dim)
    rep_X_init_1 = rep_X_init_1 / np.linalg.norm(rep_X_init_1)
    nt = sim.run_test("Basic Convergence", rep_X_init_1, t_span=(0, 500))
    if nt[4]:
        num_converged += 1
    all_results.append(('Basic Convergence', nt))
    
    # Test 2: Robustness to perturbations (increase eta_max)
    sim.eta_max = 0.6  # Increase perturbation magnitude
    rep_X_init_2 = np.random.randn(sim.dim)
    rep_X_init_2 = rep_X_init_2 / np.linalg.norm(rep_X_init_2)
    nt = sim.run_test("Robustness to Perturbations", rep_X_init_2, t_span=(0, 550))
    if nt[4]:
        num_converged += 1
    all_results.append(('Robustness to Perturbations', nt))
    
    # Test 3: Metabolic effects (manipulate gamma)
    sim.eta_max = 0.1  # Restore original perturbation magnitude
    sim.gamma_min = 0.2  # Increase base metabolic damping
    rep_X_init_3 = np.random.randn(sim.dim)
    rep_X_init_3 = rep_X_init_3 / np.linalg.norm(rep_X_init_3)
    nt = sim.run_test("Metabolic Effects", rep_X_init_3, t_span=(0, 500))
    if nt[4]:
        num_converged += 1
    all_results.append(('Metabolic Effects', nt))
    
    # Test 4: Multimodal integration test (manipulate M, S, E)
    sim.gamma_min = 0.1  # Restore original damping    
    # Create a representation that's initially aligned with sensory but not memory/emotion
    rep_X_init_4 = 0.8 * sim.S + 0.2 * np.random.randn(sim.dim)
    rep_X_init_4 = rep_X_init_4 / np.linalg.norm(rep_X_init_4)    
    nt = sim.run_test("Multimodal Integration", rep_X_init_4, t_span=(0, 550))
    if nt[4]:
        num_converged += 1
    all_results.append(('Multimodal Integration', nt))    
    
    print("\n" + "=" * 60)
    print("SIMULATION SUMMARY AND RELEVANCE TO Lemma I.D")
    print("=" * 60)
    
    if num_converged == 4:
        print("ALL TESTS CONVERGED! This strongly supports Lemma I.D's predictions.")
    elif num_converged >= 2:
        print(f"{num_converged}/4 TESTS CONVERGED. This provides partial support for Lemma I.D.")
    else:
        print(f"Only {num_converged}/4 TESTS CONVERGED. The simulation parameters may need adjustment,")
        print("or Lemma I.D may require refinement under these particular conditions.")
    
    print("\nThe simulation tested the following aspects of Lemma I.D:")
    print("1. Symbolic representations (Rep_X) should converge to stable attractor states")
    print("   through the recursive integration of embodied inputs.")
    print("2. The dynamics follow the equation:")
    print("   dRep_X/dt = -∇V(Rep_X) - γ(t)Rep_X + φ(Rep_X)η(t) + I_X(t)")
    print("   where I_X(t) = -∇V_ext(Rep_X) encodes sensorimotor constraints.")
    print("3. The Lyapunov function V(Rep_X) should decrease over time if the")
    print("   system is dissipative and converges to a stable fixed point.")
    print("4. The integration of memory (M), sensory (S), and emotional (E)")
    print("   components should ground the symbolic representation in embodied experience.")
    print("5. Metabolic constraints, modeled by the damping term γ(t), should ensure")
    print("   energetic efficiency while maintaining representational stability.")    


if __name__ == "__main__":
    run_all_tests()